"""
Train a multi-decoder VAE on translated image patterns: one shared Burgess
encoder produces `dimension` latents, and each latent dimension feeds its own
single-latent Burgess decoder. The reconstruction is the first decoder's
output minus the remaining decoders' outputs. Metrics, reconstructions,
latent/ground-truth correlations, and latent traversals are logged to
Weights & Biases.
"""
import argparse
from collections import defaultdict

from torch import optim
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models import encoders, decoders
from disvae.models.anneal import *
from disvae.models.losses import BtcvaeLoss, get_loss_s, _kl_normal_loss, _reconstruction_loss
from exps.models import *
from exps.visualize import *
from exps.patterns import data_generator
import numpy as np
from utils.helpers import set_seed

# Default hyperparameters; each entry becomes an overridable CLI flag below.
hyperparameter_defaults = {
    'batch_size': 256,
    'learning_rate': 5e-4,
    'epochs': 441,
    'beta': 10,
    'dimension': 3,
    'img_id': 0,
    'random_seed': 210,
    'file': __file__.split('/')[-1],
}

parser = argparse.ArgumentParser()
for name, default in hyperparameter_defaults.items():
    # Each flag's parser type is inferred from its default's type.
    parser.add_argument(f'--{name}', default=default, type=type(default))
args = parser.parse_args()

wandb.init(project="exps", config=args, )
config = args
set_seed(config.random_seed)


def train(dl, model, epochs):
    """Train the encoder/decoder ensemble for `epochs` passes over `dl`.

    The loss is a beta-weighted KL term on the encoder's latent
    distribution plus a Gaussian reconstruction loss, where the
    reconstruction is decoder_list[0]'s output minus the outputs of all
    remaining decoders (decoder i sees only latent dimension i).

    NOTE(review): besides its arguments this function reads the
    module-level globals `config`, `encoder` and `decoder_list`; `model`
    supplies the optimizer's parameters and is passed to `get_preds`.

    Parameters
    ----------
    dl : DataLoader yielding (img, label) batches of 64x64 images.
    model : container module holding all trainable parameters
        (built at module level from `encoder` and `decoder_list`).
    epochs : int, number of training epochs.

    Returns
    -------
    The last epoch's `storer` dict of metric name -> epoch-mean value
    (plus wandb media objects on logging epochs).
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    for e in trange(epochs):
        storer = defaultdict(list)
        for img, label in dl:
            img = img.view(-1, 1, 64, 64).cuda()
            latent_dist = encoder(img)
            latent_sample = reparameterize(*latent_dist)

            # KL term; _kl_normal_loss also records per-dim 'kl_*'
            # entries into storer.
            loss = config.beta * _kl_normal_loss(*latent_dist, storer)

            # Reconstruction = first decoder's output minus all later
            # decoders'; decoder i reads only latent dimension i.
            recon_batch = None
            for i, decoder in enumerate(decoder_list):
                p = decoder(latent_sample[:, i:i + 1])
                recon_batch = p if recon_batch is None else recon_batch - p

            recon_loss = _reconstruction_loss(img, recon_batch, distribution='gaussian')
            loss = loss + recon_loss
            storer['recon_loss'].append(recon_loss.item())

            opt.zero_grad()
            loss.backward()
            opt.step()

        # Collapse per-batch metric lists to epoch means.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)

        if (e + 1) % 40 == 0:
            # Periodic qualitative logging: last batch's reconstruction
            # and input, latent-vs-ground-truth correlation of the three
            # highest-KL dims, and a 1-D traversal through each decoder.
            storer['recon'] = wandb.Image(recon_batch[0, 0])
            storer['img'] = wandb.Image(img[0, 0])

            kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
            _, index = kl_loss.sort(descending=True)

            preds, targets = get_preds(model, dl)
            fig = gt_vs_latent(preds[:, index[:3]].cpu(), targets.cpu())
            storer['correlated'] = wandb.Image(fig)
            plt.close(fig)

            dim_len, traversal_len, r = 3, 7, 3
            fig, axes = plt.subplots(dim_len, traversal_len, squeeze=False,
                                     figsize=(traversal_len, dim_len))
            axes = axes.reshape(dim_len, traversal_len)
            plt.tight_layout(0.1)
            mu = torch.zeros(1, 1)
            for i in range(dim_len):
                linear_traversal = torch.linspace(-r, r, traversal_len)
                traversals = traversal_latents(mu, linear_traversal, 0).cuda()
                # `model` wraps decoder_list[1:], so model[i - 1] is
                # decoder_list[i] for i >= 1 — index decoder_list
                # uniformly instead of branching on i.
                recon_batch = decoder_list[i](traversals)
                plot_bar(axes[i], recon_batch[:, 0].cpu().data)
            storer['traversal_delta'] = wandb.Image(fig)
            plt.close(fig)

        storer['epoch'] = e
        wandb.log(storer, sync=True)
    return storer


# generate data
img_id = config.img_id
# Load a pre-generated base pattern and wrap it in a translation dataset.
img = torch.load(f'patterns/{img_id}.pat')
# img = data_generator.rotate(img,45)
dataset = data_generator.Translation(img)

# dataset.imgs = dataset.imgs[:40]

epochs = config.epochs
dim = config.dimension  # number of latent dimensions == number of decoders

dl = DataLoader(dataset, pin_memory=True,
                num_workers=0, batch_size=config.batch_size, shuffle=True)

# One shared encoder producing `dim` latents; one single-latent decoder per
# dimension. `model` is built from decoder_list[1:] (so model[i] is
# decoder_list[i + 1]); the encoder and decoder_list[0] are then attached as
# named submodules so .parameters()/.train()/.cuda() cover everything.
encoder = encoders.EncoderBurgess((1, 64, 64), dim)
decoder_list = [decoders.DecoderBurgess((1, 64, 64), 1) for _ in range(dim)]
model = torch.nn.Sequential(*decoder_list[1:])
model.add_module('encoder', encoder)
model.add_module('decoder', decoder_list[0])
model.train()
model.cuda()

# `train` reads `encoder`, `decoder_list` and `config` from module scope.
storer = train(dl, model, epochs)

# Rank latent dimensions by their epoch-mean KL, largest first.
kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
_, index = kl_loss.sort(descending=True)
print(index)
model.eval()
model.cpu()

# 3-D point cloud of the three highest-KL latent dims over the dataset.
preds, targets = get_preds(model, dl)
points = preds[:, index[:3]].cpu()
storer['point_cloud'] = wandb.Object3D(points.numpy())

# Traversal figure: one row per latent dimension, one column per step.
dim_len, traversal_len, r = 3, 20, 3
fig, axes = plt.subplots(dim_len, traversal_len, squeeze=False,
                         figsize=(traversal_len, dim_len))
axes = axes.reshape(dim_len, traversal_len)
plt.tight_layout(0.1)
mu = torch.zeros(1, dim_len)
linear_traversal = torch.linspace(-r, r, traversal_len)
for dim in range(dim_len):
    traversals = traversal_latents(mu, linear_traversal, dim)
    # Reconstruction = first decoder's output minus the rest; decoder i
    # reads only latent dimension i of the traversal.
    recon_list = [dec(traversals[:, i:i + 1])
                  for i, dec in enumerate(decoder_list)]
    recon_batch = recon_list[0] - sum(recon_list[1:])
    plot_bar(axes[dim], recon_batch[:, 0].cpu().data)
storer['traversal'] = wandb.Image(fig)
wandb.log(storer)
plt.show()
